from gurobipy import *
import numpy as np
from math import sqrt
import torch
import time
from scipy.optimize import LinearConstraint, NonlinearConstraint
from scipy.optimize import minimize
from math import isnan

def prepare_weights(arr, num_clusters):
    """Count how many entries of ``arr`` fall into each cluster index.

    Every element of ``arr`` is truncated with ``int()`` and used as an index
    into a zero-initialised float vector of length ``num_clusters``.

    Returns:
        A NumPy float array of length ``num_clusters`` holding the counts.
    """
    counts = np.zeros(num_clusters)
    labels = np.array([int(label) for label in arr], dtype=np.intp)
    # np.add.at accumulates correctly when the same index occurs repeatedly.
    np.add.at(counts, labels, 1)
    return counts


def simplex_projection(y):
    """Euclidean projection of a 1-D torch tensor onto the probability simplex.

    Implements the sort-based rule: with ``u`` the entries of ``y`` sorted in
    decreasing order, find the last index ``rho`` for which
    ``u[rho] + (1 - sum(u[:rho + 1])) / (rho + 1) > 0``, shift ``y`` by
    ``(1 - sum(u[:rho + 1])) / (rho + 1)`` and clip negative entries to zero.

    The whole computation stays in torch, so it also works for tensors that
    require grad (the previous NumPy ``argsort`` round-trip did not).

    Args:
        y: 1-D torch tensor.

    Returns:
        A tensor of the same length with non-negative entries summing to 1.
    """
    d = y.shape[0]
    if d == 0:
        # Nothing to project.
        return y.clone()

    u, _ = torch.sort(y, descending=True)
    prefix = torch.cumsum(u, dim=0)
    ranks = torch.arange(1, d + 1, dtype=prefix.dtype, device=prefix.device)

    # Indices that satisfy the support condition; the last one is rho.
    support = torch.nonzero(u + (1 - prefix) / ranks > 0)
    # For finite input index 0 always qualifies; fall back to it otherwise
    # (e.g. NaN input) so that the NaN simply propagates to the result.
    rho_idx = int(support.max()) if support.numel() else 0

    shift = (1 - prefix[rho_idx]) / (rho_idx + 1)
    return torch.clamp(y + shift, min=0)

def M_simplex_projection(M, y, max_iter=100, tol=1e-12):
    """Project ``y`` onto ``{x : 0 <= x <= 1, sum(x) = M}``.

    The projection has the form ``clip(y - gamma, 0, 1)``; ``gamma`` is the
    root of ``g(gamma) = sum(clip(y - gamma, 0, 1)) - M``, which is
    non-increasing and piecewise linear in ``gamma``.

    The root is found with a safeguarded Newton iteration. The Newton step
    divides by the number of coordinates currently strictly inside ``[0, 1]``;
    when that count is zero, or the step would leave the bracket known to
    contain the root, a bisection step is taken instead. (The unguarded
    Newton update could divide by zero and return NaN.)

    Args:
        M: target sum of the projected vector.
        y: 1-D torch tensor to project.
        max_iter: maximum number of root-finding iterations.
        tol: absolute tolerance on ``|sum(x) - M|`` for early termination.

    Returns:
        Tensor with the dtype of ``y`` and entries in ``[0, 1]``.
    """
    target = float(M)
    y_min = torch.min(y).item()
    y_max = torch.max(y).item()

    # g(lo) = d - M (everything clipped to 1), g(hi) = -M (everything 0).
    lo = y_min - 1.0
    hi = y_max
    gamma = y_min - 0.5

    for _ in range(max_iter):
        v = y - gamma
        resid = target - torch.clamp(v, 0.0, 1.0).sum().item()
        if abs(resid) <= tol:
            break
        # The clipped sum decreases in gamma: too small a sum means gamma is
        # too large, and vice versa.
        if resid > 0:
            hi = gamma
        else:
            lo = gamma

        active = int(((v >= 0) & (v <= 1)).sum())
        candidate = gamma - resid / active if active > 0 else None
        if candidate is None or not (lo < candidate < hi):
            candidate = 0.5 * (lo + hi)
        gamma = candidate

    return torch.clamp(y - gamma, 0.0, 1.0)

def gamma_est(func, theta, m, K, N):
    """Tabulate the increments of ``func`` on a uniform grid of ``K`` steps.

    For every sample ``n`` the row contains, for ``j = 0..m-1`` and
    ``k = 0..K-1`` (``j`` outer, ``k`` inner), the difference
    ``func(theta, j, (k + 1) / K, n) - func(theta, j, k / K, n)``.

    Returns:
        NumPy array built from ``N`` rows of ``m * K`` increments.
    """
    rows = [
        [
            func(theta, j, (k + 1) / K, n) - func(theta, j, k / K, n)
            for j in range(m)
            for k in range(K)
        ]
        for n in range(N)
    ]
    return np.array(rows)

def gamma_zero(func, theta, m, N):
    """Sum ``func(theta, j, 0, n)`` over ``j = 0..m-1`` for every sample ``n``.

    Returns:
        NumPy array with one summed value per sample.
    """
    intercepts = []
    for n in range(N):
        acc = 0
        for j in range(m):
            acc = acc + func(theta, j, 0, n)
        intercepts.append(acc)
    return np.array(intercepts)

def numerator(theta, j, p, n, lib=0):
    """Evaluate ``(a * p + b) * exp(c * p + d)`` with ``a, b, c, d = theta[n][j][0:4]``.

    ``lib`` selects the exponential: truthy uses ``torch.exp``, falsy uses
    ``np.exp``.
    """
    coef = theta[n][j]
    exp_fn = torch.exp if lib else np.exp
    linear_part = coef[0] * p + coef[1]
    return linear_part * exp_fn(coef[2] * p + coef[3])

def denominator(theta, j, p, n,lib=0):
    """Evaluate ``exp(c * p + d)`` with ``c, d = theta[n][j][2], theta[n][j][3]``.

    ``lib`` selects the exponential: truthy uses ``torch.exp``, falsy uses
    ``np.exp``.
    """
    coef = theta[n][j]
    exponent = coef[2] * p + coef[3]
    exp_fn = torch.exp if lib else np.exp
    return exp_fn(exponent)

def FCP_numerator(theta, j, p, n, lib=0):
    """Return ``exp(theta[0][n][j] * p + theta[1][n][j])`` using NumPy.

    ``theta`` is indexed as ``[theta_0, theta_1, theta_2]``; only the first
    two entries are read here, each as ``[n][j]``. ``lib`` is accepted for
    signature compatibility and is not used.
    """
    slope = theta[0][n][j]
    offset = theta[1][n][j]
    return np.exp(slope * p + offset)

def FCP_denominator(theta, j, p, n,lib=0):
    """Return ``exp(theta[0][n][j] * p + theta[1][n][j]) + exp(theta[2][n]) / m``.

    ``m`` is ``len(theta[0][0])``, so summing this value over all ``j`` adds
    ``exp(theta[2][n])`` exactly once. ``lib`` is accepted for signature
    compatibility and is not used.
    """
    n_alternatives = len(theta[0][0])
    own_term = np.exp(theta[0][n][j] * p + theta[1][n][j])
    shared_term = np.exp(theta[2][n]) / n_alternatives
    return own_term + shared_term

def utility(theta, num, den, m, p, N):
    """Average, over ``N`` samples, of the ratio ``sum_j num / sum_j den``.

    ``num`` and ``den`` are called as ``f(theta, j, p[j], n, 1)`` for
    ``j = 0..m-1``; the trailing ``1`` is their ``lib`` flag.

    Returns:
        The mean of the per-sample ratios.
    """
    running = 0.0
    for n in range(N):
        top = sum((num(theta, j, p[j], n, 1) for j in range(m)), 0.0)
        bottom = sum((den(theta, j, p[j], n, 1) for j in range(m)), 0.0)
        running += top / bottom
    return running / N

def utility_and_worst_and_variance(theta, num, den, m, p, N):
    """Mean, minimum and spread of the per-sample ratios ``sum_j num / sum_j den``.

    ``num`` and ``den`` are called as ``f(theta, j, p[j], n, 1)``.

    Returns:
        ``(mean, worst, spread)`` where ``worst`` is the smallest per-sample
        ratio and ``spread`` is the square root of the population variance
        ``(sum(f^2) - sum(f)^2 / N) / N``, as a Python float.
    """
    total = 0.0
    total_sq = 0.0
    # float("inf") avoids depending on a `math` module name that this file
    # never imports explicitly.
    worst = float("inf")
    for n in range(N):
        num_val = 0.0
        den_val = 0.0
        for j in range(m):
            num_val += num(theta, j, p[j], n, 1)
            den_val += den(theta, j, p[j], n, 1)
        frac = num_val / den_val
        if frac <= worst:
            worst = frac
        total += frac
        total_sq += frac ** 2
    # Rounding can push the variance of (nearly) identical ratios slightly
    # below zero, which would make sqrt raise; clamp it at zero.
    variance = float((total_sq - (total ** 2) / N) / N)
    spread = sqrt(max(variance, 0.0))
    return total / N, worst, spread

def utility_and_variance(theta, num, den, m, p, N):
    """Mean and spread of the per-sample ratios ``sum_j num / sum_j den``.

    ``num`` and ``den`` are called as ``f(theta, j, p[j], n, 1)``.

    Returns:
        ``(mean, spread)`` where ``spread`` is the square root of the
        population variance ``(sum(f^2) - sum(f)^2 / N) / N``, as a Python
        float.
    """
    total = 0.0
    total_sq = 0.0
    for n in range(N):
        num_val = 0.0
        den_val = 0.0
        for j in range(m):
            num_val += num(theta, j, p[j], n, 1)
            den_val += den(theta, j, p[j], n, 1)
        frac = num_val / den_val
        total += frac
        total_sq += frac ** 2
    # Rounding can push the variance of (nearly) identical ratios slightly
    # below zero, which would make sqrt raise; clamp it at zero.
    variance = float((total_sq - (total ** 2) / N) / N)
    spread = sqrt(max(variance, 0.0))
    return total / N, spread

def utility_robust(theta, num, den, m, p, N, xi, weights, lib=0, N_act=None):
    """Weighted mean of the per-sample ratios minus a variance penalty.

    With ``f_n = sum_j num / sum_j den`` and ``W = sum(weights[:N])`` the
    value is::

        sum(w_n f_n) / W - rho * sqrt(sum(w_n f_n^2) - (sum(w_n f_n))^2 / W)

    where ``rho = sqrt(2 * xi) / N_act`` (``N`` is used when ``N_act`` is
    None).

    When the ratios are torch tensors the square root is taken with
    ``torch.sqrt`` so that the penalty term takes part in autograd. (Using
    ``math.sqrt`` converted it to a plain float, so ``backward()`` only saw
    the mean term.)

    Args:
        theta, num, den, m, p: forwarded to ``num`` / ``den`` as
            ``f(theta, j, p[j], n, lib)``.
        N: number of samples to iterate over.
        xi: radius parameter; must be a number.
        weights: per-sample weights, indexed ``weights[n]``.
        lib: flag forwarded to ``num`` / ``den``.
        N_act: optional sample count used in ``rho`` instead of ``N``.
    """
    total = 0.0
    total_sq = 0.0
    N_tot = 0
    for n in range(N):
        num_val = 0.0
        den_val = 0.0
        for j in range(m):
            num_val += num(theta, j, p[j], n, lib)
            den_val += den(theta, j, p[j], n, lib)
        frac = num_val / den_val
        total += weights[n] * frac
        total_sq += weights[n] * frac ** 2
        N_tot += weights[n]

    rho = sqrt(2 * xi) / (N if N_act is None else N_act)
    radicand = total_sq - (total ** 2) / N_tot

    if torch.is_tensor(radicand):
        value = float(radicand)
        if value > 0.0 or value != value:
            # Positive (or NaN, which must keep propagating): differentiable.
            spread = torch.sqrt(radicand)
        else:
            # Zero or a tiny negative rounding residue: the penalty vanishes.
            # A constant is used because d/dx sqrt(x) is infinite at 0.
            spread = 0.0
    else:
        spread = sqrt(max(radicand, 0.0))
    return total / N_tot - rho * spread


def generate_rationality_dirichlet(N, alpha, support=5.0):
    """Draw ``N`` Dirichlet(``alpha``) vectors and scale them by ``support``.

    Each returned row therefore sums to ``support``.
    """
    return support * np.random.dirichlet(alpha, N)

def generate_rationality_beta(N, alpha, support=5.0, lower_rat=0.0):
    """Draw ``N`` samples per dimension from Beta distributions.

    Row ``d`` of ``alpha`` holds the two Beta shape parameters of dimension
    ``d``. Each draw is scaled by ``support / D`` (``D = alpha.shape[0]``)
    and shifted by ``lower_rat``.

    Returns:
        Array of shape ``(N, D)``.
    """
    n_dims = alpha.shape[0]
    columns = [
        lower_rat + support * np.random.beta(shape[0], shape[1], N) / n_dims
        for shape in alpha
    ]
    return np.array(columns).T

def generate_rationality_beta_perturbed(N, alpha, support=5.0, eps=0.5):
    """Beta samples whose shape parameters are jointly rescaled at random.

    A single factor ``1 + perb`` with ``perb`` uniform in ``[-eps, eps)`` is
    drawn, printed, and applied to every entry of ``alpha``; then ``N``
    samples per dimension are drawn and scaled by ``support / D``
    (``D = alpha.shape[0]``).

    Returns:
        Array of shape ``(N, D)``.
    """
    n_dims = alpha.shape[0]
    perb = 2 * eps * np.random.rand() - eps
    print("Perb : ",perb)
    scaled = (1 + perb) * alpha
    draws = [
        support * np.random.beta(shape[0], shape[1], N) / n_dims
        for shape in scaled
    ]
    return np.array(draws).T

def perturb_mu(mu, perb):
    """Multiply each entry of ``mu`` by an independent factor in ``[1 - perb, 1 + perb)``.

    One uniform draw is made per entry of ``mu`` (``mu.shape[0]`` draws).
    """
    signed_noise = 2 * np.random.rand(mu.shape[0]) - 1
    factor = 1 + perb * signed_noise
    return mu * factor

def generate_params(m, D):
    """Sample random payoff parameters (general-sum variant).

    Returns:
        ``(W_D, W_A)`` where ``W_D`` has shape ``(m, 2)`` with entries uniform
        in ``[0, 10)`` and ``W_A`` has shape ``(m, 2, D)`` with entries in
        ``(-1, 0]``.
    """
    scale_d = 10.0
    scale_a = 1.0

    # Draw order matters for reproducibility under a fixed seed: W_D first.
    defender = np.random.rand(m, 2)
    attacker = -np.random.rand(m, 2, D)

    # NOTE: an earlier, disabled variant built a zero-sum game by deriving
    # W_D from W_A; it is not part of this function's behaviour.
    return scale_d * defender, scale_a * attacker

def generate_params_normal(m, shift=5.0):
    """Sample a random mean vector and a positive semi-definite covariance.

    The covariance is ``0.25 * L @ L.T`` for a uniform random ``(m, m)``
    matrix ``L``; the mean is ``10 * U + shift`` with ``U`` uniform in
    ``[0, 1)^m``.

    Returns:
        ``(mu, Cov)``.
    """
    scale_cov = 0.25
    scale_mu = 10.0

    # Draw order matters for reproducibility under a fixed seed: L first.
    factor = np.random.rand(m, m)
    covariance = scale_cov * factor.dot(factor.T)
    mean = scale_mu * np.random.rand(m) + shift
    return mean, covariance

def generate_mu(m, low, high, a, b):
    """Draw ``m`` values from Beta(``a``, ``b``) rescaled to ``[low, high]``."""
    width = high - low
    return low + width * np.random.beta(a, b, m)

def generate_theta(W_A, W_D, rat):
    """Assemble per-sample parameters from shared weights and sample vectors.

    For each row ``rat[i]`` the last axis of ``W_A`` is contracted with it
    (``sum(rat[i] * W_A, axis=2)``); the result is appended, along axis 2, to
    a copy of ``W_D``.

    Returns:
        Array whose first axis indexes the rows of ``rat`` and whose last
        axis is ``W_D``'s last axis followed by the contracted ``W_A`` axis.
    """
    n_samples = rat.shape[0]
    attacker_part = np.array([(row * W_A).sum(axis=2) for row in rat])
    defender_part = np.array([W_D] * n_samples)
    return np.concatenate([defender_part, attacker_part], axis=2)

def generate_theta_normal(mu_0, cov_0, mu_1, cov_1, mu_2, cov_2, mu_3, cov_3, N):
    """Sample per-sample and shared parameters from multivariate normals.

    ``(mu_0, cov_0)`` and ``(mu_1, cov_1)`` are sampled ``N`` times each,
    negated, paired along a new last axis and scaled by 0.1.
    ``(mu_2, cov_2)`` and ``(mu_3, cov_3)`` are sampled once each, paired the
    same way and repeated for all ``N`` samples.

    Returns:
        ``(theta_1, theta_2)``: the repeated shared block and the per-sample
        block, both with a trailing axis of length 2.
    """
    scale_d = 1
    scale_a = 0.1

    def draw(mean, cov, count):
        return np.random.multivariate_normal(mean, cov, count)

    # Draw order matters for reproducibility under a fixed seed.
    per_sample_0 = -draw(mu_0, cov_0, N).reshape(N, -1, 1)
    per_sample_1 = -draw(mu_1, cov_1, N).reshape(N, -1, 1)
    shared_0 = draw(mu_2, cov_2, 1).reshape(-1, 1)
    shared_1 = draw(mu_3, cov_3, 1).reshape(-1, 1)

    shared = np.concatenate([shared_0, shared_1], axis=1)
    theta_1 = scale_d * np.array([shared] * N)
    theta_2 = scale_a * np.concatenate([per_sample_0, per_sample_1], axis=2)
    return theta_1, theta_2

def generate_theta_normal_zero_sum(mu_0, cov_0, mu_1, cov_1, N):
    """Sample parameters where the first block is the negated second block.

    Two sets of ``N`` multivariate-normal draws are negated and paired along
    a new last axis to form a base block ``B``. The result concatenates
    ``2 * (-B)`` and ``0.1 * B`` along the last axis.

    Returns:
        Array with a trailing axis of length 4.
    """
    scale_d = 2
    scale_a = 0.1

    # Draw order matters for reproducibility under a fixed seed.
    first = -np.random.multivariate_normal(mu_0, cov_0, N).reshape(N, -1, 1)
    second = -np.random.multivariate_normal(mu_1, cov_1, N).reshape(N, -1, 1)
    base = np.concatenate([first, second], axis=2)

    mirrored = scale_d * (-base)
    scaled = scale_a * base
    return np.concatenate([mirrored, scaled], axis=2)

def compute_params(m, K, N, numerator, denominator, theta):
    """Tabulate grid increments/intercepts of two functions and bounds on 1/denominator.

    ``A``/``b`` are ``gamma_est``/``gamma_zero`` of ``numerator`` and
    ``C``/``d`` the same for ``denominator``. ``tL`` is the reciprocal of
    ``d`` plus the sum of the positive entries of each row of ``C``; ``tU``
    is the reciprocal of ``d`` plus the sum of the negative entries.

    Returns:
        ``(A, b, C, d, tL, tU)``.
    """
    A = gamma_est(numerator, theta, m, K, N)
    b = gamma_zero(numerator, theta, m, N)
    C = gamma_est(denominator, theta, m, K, N)
    d = gamma_zero(denominator, theta, m, N)

    no_change = np.zeros((N, m * K))
    largest_den = np.maximum(C, no_change).sum(axis=1) + d
    smallest_den = np.minimum(C, no_change).sum(axis=1) + d
    return A, b, C, d, 1.0 / largest_den, 1.0 / smallest_den

def upper_lower_t(V, U):
    """Return ``(tL, tU) = (1 / (sum_j V[i, j] + U[i]), 1 / U[i])`` per row ``i``."""
    row_totals = V.sum(axis=1)
    lower = 1.0 / (row_totals + U)
    upper = 1.0 / U
    return lower, upper

def FCP_ERM(m, K, N, M, tL, tU, A, b, C, d, weights):
    """Build and solve a Gurobi mixed-integer model maximising ``sum(f)``.

    Variables: binary ``z`` (length ``m * K``), continuous ``f`` and ``t``
    (length ``N``, ``t`` bounded by ``tL``/``tU``) and ``Y`` (``N x m*K``).
    The constraints tie ``f[i]`` to ``weights[i] * (A[i] @ Y[i] + b[i] * t[i])``
    with ``C[i] @ Y[i] + d[i] * t[i] == 1``.

    Returns:
        Array of length ``m``: for each block of ``K`` consecutive ``z``
        entries, the fraction that is set to 1 in the solution.

    NOTE(review): ``f[i] / weights[i]`` divides by the weight; confirm callers
    never pass a zero weight.
    """
    
    opt = Model()

    z = opt.addMVar(m * K, vtype=GRB.BINARY, name="Z")
    f = opt.addMVar(N, name="F")
    t = opt.addMVar(N, lb=tL, ub=tU, name="t")
    Y = opt.addMVar((N, m * K), name="Y")

    # The return values bound to objective / lc* names are not used later.
    objective = opt.setObjective(sum(f), GRB.MAXIMIZE)
    opt.setParam('NumericFocus', 3)
    opt.setParam('TimeLimit', 180*60)

    # At most M*K of the binaries may be switched on.
    lc8 = opt.addConstr(sum(z) - M*K <= 0)
    # Within each block of K binaries the values are non-increasing.
    lc9 = opt.addConstrs((z[j * K + k + 1] - z[j * K + k] <= 0) for j in range(m) for k in range(K-1))

    # l10 = opt.addConstr(t - tL >= 0.0)
    # l11 = opt.addConstr(t - tU <= 0.0)

    # lc4-lc7 are the McCormick inequalities for Y[i][j] = t[i] * z[j] with
    # t[i] in [tL[i], tU[i]] and z[j] binary.
    lc4 = opt.addConstrs((Y[i] - tU[i] * z <= 0) for i in range(N))
    lc5 = opt.addConstrs((Y[i][j] - (tL[i] * z[j] + t[i] - tL[i] ) <= 0) for i in range(N) for j in range(m * K))
    lc6 = opt.addConstrs((-Y[i][j] + (t[i] + tU[i] * z[j] - tU[i]) <= 0) for i in range(N) for j in range(m * K))
    lc7 = opt.addConstrs((-Y[i]  + tL[i] * z <= 0) for i in range(N))

    # Normalisation: C[i] @ Y[i] + d[i] * t[i] must equal 1 for every sample.
    lc3 = opt.addConstrs((C[i] @ Y[i] + d[i] * t[i] - 1 == 0) for i in range(N))    
    # lc31 = opt.addConstrs((C[i] @ Y[i] + d[i] * t[i] - (1 + eps) <= 0) for i in range(N))
    # lc32 = opt.addConstrs((C[i] @ Y[i] + d[i] * t[i] - (1 - eps) >= 0) for i in range(N))

    # Defines f[i] = weights[i] * (A[i] @ Y[i] + b[i] * t[i]).
    lc2 = opt.addConstrs((f[i] / weights[i] - A[i] @ Y[i] - b[i] * t[i] == 0 for i in range(N)))

    opt.optimize()
    
    # Average each block of K binaries to get one value per alternative.
    Z_opt_grb = np.sum((z.x).reshape(m, K), axis=1)/K

    return Z_opt_grb

def FCP_DRO(m, K, N, N_act, M, tL, tU, A, b, C, d, weights, xi=1.0):
    """Build and solve a Gurobi model maximising ``q - rho * l[N]``.

    ``rho = sqrt(2 * xi) / N_act``. The variables ``z``, ``t`` and ``Y`` and
    the constraints on them match ``FCP_ERM``, except that ``sum(z)`` must
    equal ``M * K``. For each sample, ``l[i] / sqrt(weights[i])`` is the gap
    between ``q`` and ``A[i] @ Y[i] + b[i] * t[i]``; a quadratic constraint
    bounds the norm of ``l[0:N]`` by ``l[N]``.

    Prints ``q`` and ``l[N]`` from the solution.

    Returns:
        Array of length ``m``: for each block of ``K`` consecutive ``z``
        entries, the fraction that is set to 1 in the solution.
    """

    rho = sqrt(2 * xi)/ N_act
    opt = Model()
    opt.setParam('NumericFocus', 3)
    z = opt.addMVar(m * K, vtype=GRB.BINARY, name="Z")
    # l[0:N] are per-sample deviations; l[N] carries the norm bound.
    l = opt.addMVar(N+1, lb=-float('inf'), name="L")
    # l_aux = opt.addMVar(1, lb=0.0, name="L_aux")
    q = opt.addMVar(1, name="q")
    t = opt.addMVar(N, lb=tL, ub=tU, name="t")
    Y = opt.addMVar((N, m * K), name="Y")


    opt.setParam('TimeLimit', 60*60)

    # The return values bound to objective / lc* names are not used later.
    objective = opt.setObjective(q - rho * l[N], GRB.MAXIMIZE)

    # opt.params.NonConvex = 2
    opt.params.PreMIQCPForm = 1

    # Exactly M*K binaries are switched on (FCP_ERM uses <= here).
    lc8 = opt.addConstr(sum(z) - M*K == 0)
    # Within each block of K binaries the values are non-increasing.
    lc9 = opt.addConstrs((z[j * K + k + 1] - z[j * K + k] <= 0) for j in range(m) for k in range(K-1))
    lc_pos = opt.addConstr(l[N] >= 0)
    # l10 = opt.addConstr(t - tL >= 0.0)
    # l11 = opt.addConstr(t - tU <= 0.0)

    # lc4-lc7 are the McCormick inequalities for Y[i][j] = t[i] * z[j] with
    # t[i] in [tL[i], tU[i]] and z[j] binary.
    lc4 = opt.addConstrs((Y[i] - tU[i] * z <= 0) for i in range(N))
    lc5 = opt.addConstrs((Y[i][j] - (tL[i] * z[j] + t[i] - tL[i] ) <= 0) for i in range(N) for j in range(m * K))
    lc6 = opt.addConstrs((-Y[i][j] + (t[i] + tU[i] * z[j] - tU[i]) <= 0) for i in range(N) for j in range(m * K))
    lc7 = opt.addConstrs((-Y[i]  + tL[i] * z <= 0) for i in range(N))

    # Normalisation: C[i] @ Y[i] + d[i] * t[i] must equal 1 for every sample.
    lc3 = opt.addConstrs((C[i] @ Y[i] + d[i] * t[i] - 1 == 0) for i in range(N))
    # lc31 = opt.addConstrs((C[i] @ Y[i] + d[i] * t[i] - (1 + eps) <= 0) for i in range(N))
    # lc32 = opt.addConstrs((C[i] @ Y[i] + d[i] * t[i] - (1 - eps) >= 0) for i in range(N))


    # q - l[i] / sqrt(weights[i]) equals the per-sample value A[i] @ Y[i] + b[i] * t[i].
    lc2 = opt.addConstrs((q - l[i] / sqrt(weights[i]) - A[i] @ Y[i] - b[i] * t[i] == 0 for i in range(N)))

    # NOTE(review): l has N+1 entries, so this product only conforms if
    # weights has N+1 entries as well -- confirm against the callers.
    weights = weights.reshape(1,-1)
    lc1 = opt.addConstr(np.sqrt(weights) @ l == 0.0)
    # lc1 = opt.addConstr(sum(l) == 0.0)

    # aux_c = opt.addQConstr(l_aux @ l_aux - l @ l == 0.0)
    # Q = diag(1, ..., 1, -1): sum_{i<N} l[i]^2 - l[N]^2 <= 0.
    Q = np.eye(N+1)
    Q[-1][-1] = -1
    # xL = opt.addMVar(2)
    # xR = opt.addMVar(3)
    aux_c = opt.addMQConstr(Q, None, '<', rhs=0.0, xQ_L= l, xQ_R=l)
    # aux_c = opt.addQConstr(l_aux @ l_aux - l @ l >= 0.0)

    opt.optimize()
    
    # Average each block of K binaries to get one value per alternative.
    Z_opt_grb = np.sum((z.x).reshape(m, K), axis=1)/K
    print(q.x)
    print(l.x[-1])
    return Z_opt_grb

def FLP_DRO(m, N, N_act, M, tL, tU, A, b, C, d, weights, xi=1.0):
    """Build and solve a Gurobi model maximising ``q - rho * l[N]`` over binary ``X``.

    ``rho = sqrt(2 * xi) / N_act``. Exactly ``M`` of the ``m`` binaries in
    ``X`` are selected. For each sample, ``l[i] / sqrt(weights[i])`` is the
    gap between ``q`` and ``A[i] @ Y[i]``; a quadratic constraint bounds the
    norm of ``l[0:N]`` by ``l[N]``. The argument ``b`` is not used.

    Prints ``rho``, the literal "Dot Product", and ``q`` and ``l[N]`` from
    the solution.

    Returns:
        The solution values of ``X``.
    """

    rho = sqrt(2 * xi)/ N_act
    print(rho)
    opt = Model()
    opt.setParam('TimeLimit', 30*60)
    opt.setParam('NumericFocus', 3)
    X = opt.addMVar(m, vtype=GRB.BINARY, name="X")
    # l[0:N] are per-sample deviations; l[N] carries the norm bound.
    l = opt.addMVar(N+1, lb=-float('inf'), name="L")
    # l_eq = opt.addMVar((N,1), lb=-float('inf'), name="L")
    # u = opt.addMVar(N, name="u")
    # wl = opt.addMVar(N, name="WL")
    # l_aux is added to the model but appears in no active constraint below.
    l_aux = opt.addMVar(1, lb=0.0, ub=float('inf'), name="L_aux")
    # l_aux = opt.addVar(name="L_aux")
    q = opt.addMVar(1, name="q")
    # t = opt.addMVar(N, lb=tL, ub=tU, name="t")
    t = opt.addMVar(N, name="t")
    Y = opt.addMVar((N,m), name="Y")

    # The return values bound to objective / lc* names are not used later.
    objective = opt.setObjective(q - rho * l[N], GRB.MAXIMIZE)
    # objective = opt.setObjective(q, GRB.MAXIMIZE)

    # opt.params.NonConvex = 2
    opt.params.PreMIQCPForm = 1

    # Exactly M entries of X are selected.
    lc8 = opt.addConstr(sum(X) - M == 0)
    lc_pos = opt.addConstr(l[N] >= 0)

    # l10 = opt.addConstr(t - tL >= 0.0)
    # l11 = opt.addConstr(t - tU <= 0.0)

    # Bounds tL[i] <= t[i] <= tU[i], stated as explicit constraints.
    l10 = opt.addConstrs((t[i] - tL[i] >= 0.0) for i in range(N))
    l11 = opt.addConstrs((t[i] - tU[i] <= 0.0) for i in range(N))

    # lc4-lc7 are the McCormick inequalities for Y[i][j] = t[i] * X[j] with
    # t[i] in [tL[i], tU[i]] and X[j] binary.
    lc4 = opt.addConstrs((Y[i] - tU[i] * X <= 0) for i in range(N))
    # lc4 = opt.addConstrs((Y[i][j] - tU[i] * X[j] <= 0) for i in range(N) for j in range(m))
    lc5 = opt.addConstrs((Y[i][j] - (tL[i] * X[j] + t[i] - tL[i] ) <= 0) for i in range(N) for j in range(m))
    lc6 = opt.addConstrs((-Y[i][j] + (t[i] + tU[i] * X[j] - tU[i]) <= 0) for i in range(N) for j in range(m))
    # lc7 = opt.addConstrs((-Y[i][j]  + tL[i] * X[j] <= 0) for i in range(N) for j in range(m))
    lc7 = opt.addConstrs((-Y[i]  + tL[i] * X <= 0) for i in range(N))

    # Normalisation: C[i] @ Y[i] + d[i] * t[i] must equal 1 for every sample.
    lc3 = opt.addConstrs((C[i] @ Y[i] + d[i] * t[i] - 1 == 0) for i in range(N))
    # eps = 0.01
    # lc31 = opt.addConstrs((C[i] @ Y[i] + d[i] * t[i] - (1 + eps) <= 0) for i in range(N))
    # lc32 = opt.addConstrs((C[i] @ Y[i] + d[i] * t[i] - (1 - eps) >= 0) for i in range(N))


    # q - l[i] / sqrt(weights[i]) equals the per-sample value A[i] @ Y[i]
    # (no b[i] * t[i] term here, unlike FCP_DRO).
    lc2 = opt.addConstrs((q - l[i] / sqrt(weights[i])) - A[i] @ Y[i] == 0 for i in range(N))

    # lc1 = opt.addConstr(sum(l) == 0.0)
    # NOTE(review): l has N+1 entries, so this product only conforms if
    # weights has N+1 entries as well -- confirm against the callers.
    weights = weights.reshape(1,-1)
    lc1 = opt.addConstr(np.sqrt(weights) @ l == 0.0)
    # print("LC1")
    # print(np.sqrt(weights) @ l)
    # lc_new = opt.addConstrs(l[i] == l_eq[i][0] for i in range(N))
    # lc0 = opt.addConstrs(wl[i] == weights[i] * l[i] for i in range(N))

    # print(quicksum([l[i] * l[i] for i in range(m)]))
    # print(np.transpose(l) @ l)
    # print(sum(u))
    print("Dot Product")
    # print(l_eq @ l)
    # Q = diag(1, ..., 1, -1): sum_{i<N} l[i]^2 - l[N]^2 <= 0.
    Q = np.eye(N+1)
    Q[-1][-1] = -1
    # xL = opt.addMVar(2)
    # xR = opt.addMVar(3)
    aux_c = opt.addMQConstr(Q, None, '<', rhs=0.0, xQ_L= l, xQ_R=l)
    # aux_c = opt.addQConstr(l_aux @ l_aux - l_eq @ l >= 0.0)

    opt.optimize()
    # print(opt.display())
    
    X_opt_grb = X.x #np.sum((z.x).reshape(m, K), axis=1)/K
    print(q.x)
    # print(l_aux.x)
    print(l.x[-1])

    return X_opt_grb

def FLP_ERM(m, N, N_act, M, tL, tU, A, b, C, d, weights):
    """Build and solve a Gurobi model maximising ``sum(f)`` over binary ``X``.

    At most ``M`` of the ``m`` binaries in ``X`` may be selected. ``f[i]`` is
    tied to ``weights[i] * (Y[i] @ A[i])`` with
    ``C[i] @ Y[i] + d[i] * t[i] == 1``. The arguments ``N_act`` and ``b`` are
    not used.

    Returns:
        The solution values of ``X``.

    NOTE(review): ``f[i] / weights[i]`` divides by the weight; confirm callers
    never pass a zero weight.
    """

    # rho = sqrt(2 * xi)/ N_act
    opt = Model()
    opt.setParam('TimeLimit', 30*60)
    opt.setParam('NumericFocus', 3)
    X = opt.addMVar(m, vtype=GRB.BINARY, name="X")
    f = opt.addMVar(N, name='f')
    t = opt.addMVar(N, name="t")
    Y = opt.addMVar((N,m), name="Y")
    # q and l are added to the model but appear in no constraint or objective.
    q = opt.addMVar(1, name="q")
    l = opt.addMVar(N, name='l')

    # The return values bound to objective / lc* names are not used later.
    objective = opt.setObjective(sum(f), GRB.MAXIMIZE)

    # opt.setParam('TimeLimit', 10*60)

    # At most M entries of X are selected.
    lc8 = opt.addConstr(sum(X) - M <= 0)

    # Bounds tL[i] <= t[i] <= tU[i], stated as explicit constraints.
    l10 = opt.addConstrs((t[i] - tL[i] >= 0.0) for i in range(N))
    l11 = opt.addConstrs((t[i] - tU[i] <= 0.0) for i in range(N))

    # lc4-lc7 are the McCormick inequalities for Y[i][j] = t[i] * X[j] with
    # t[i] in [tL[i], tU[i]] and X[j] binary.
    lc4 = opt.addConstrs((Y[i] - tU[i] * X <= 0) for i in range(N))
    # lc4 = opt.addConstrs((Y[i][j] - tU[i] * X[j] <= 0) for i in range(N) for j in range(m))
    lc5 = opt.addConstrs((Y[i][j] - (tL[i] * X[j] + t[i] - tL[i] ) <= 0) for i in range(N) for j in range(m))
    lc6 = opt.addConstrs((-Y[i][j] + (t[i] + tU[i] * X[j] - tU[i]) <= 0) for i in range(N) for j in range(m))
    # lc7 = opt.addConstrs((-Y[i][j]  + tL[i] * X[j] <= 0) for i in range(N) for j in range(m))
    lc7 = opt.addConstrs((-Y[i]  + tL[i] * X <= 0) for i in range(N))


    # Normalisation: C[i] @ Y[i] + d[i] * t[i] must equal 1 for every sample.
    lc3 = opt.addConstrs((C[i] @ Y[i] + d[i] * t[i] - 1 == 0) for i in range(N))    

    # Defines f[i] = weights[i] * (Y[i] @ A[i]).
    lc2 = opt.addConstrs(f[i] / weights[i] - Y[i] @ A[i] == 0.0 for i in range(N))

    opt.optimize()
    
    # X_opt_grb is assigned but the function returns X.x directly.
    X_opt_grb = X.x #np.sum((z.x).reshape(m, K), axis=1)/K

    # return X_opt_grb, t.x, q.x, l.x
    return X.x

def FLP_values(X, V, U):
    """Per-sample ratio ``(X @ V[i]) / (X @ V[i] + U[i])`` for every row of ``V``.

    Returns:
        A Python list with one ratio per row of ``V``.
    """
    captured = [X @ V[i] for i in range(V.shape[0])]
    return [c / (c + U[i]) for i, c in enumerate(captured)]

def FLP_utility(p, theta, weights):
    """Weighted average of ``(p . V[n]) / (p . V[n] + U[n])`` over the rows of ``V``.

    Args:
        p: 1-D torch tensor.
        theta: pair ``[U, V]``; ``V`` is converted to a tensor with
            ``p``'s dtype.
        weights: per-sample weights, indexed ``weights[n]``.

    Returns:
        ``sum(w_n * ratio_n) / sum(w_n)``.
    """
    U = theta[0]
    V = torch.tensor(theta[1], dtype=p.dtype)
    weighted_sum = 0.0
    weight_mass = 0.0
    for n in range(V.shape[0]):
        captured = torch.dot(p, V[n])
        weighted_sum += weights[n] * (captured / (captured + U[n]))
        weight_mass += weights[n]
    return weighted_sum / weight_mass

def FLP_utility_robust(p, theta, weights,xi):
    """Weighted mean of the per-sample ratios minus a variance penalty.

    With ``f_n = (p . V[n]) / (p . V[n] + U[n])``, ``W = sum(w_n)`` and ``N``
    the number of rows of ``V`` the value is::

        sum(w_n f_n) / W
            - sqrt(2 * xi / N^2) * sqrt(sum(w_n f_n^2) - (sum(w_n f_n))^2 / W)

    The second square root is taken with ``torch.sqrt`` so that the penalty
    term takes part in autograd. (Using ``math.sqrt`` converted it to a plain
    float, so ``backward()`` only saw the mean term.)

    Args:
        p: 1-D torch tensor.
        theta: pair ``[U, V]``; ``V`` is converted to a tensor with
            ``p``'s dtype.
        weights: per-sample weights, indexed ``weights[n]``.
        xi: radius parameter.
    """
    U = theta[0]
    V = torch.tensor(theta[1], dtype=p.dtype)
    N = V.shape[0]

    total = 0.0
    total_sq = 0.0
    total_w = 0.0
    for n in range(N):
        ut1 = torch.dot(p, V[n])
        frac = ut1 / (ut1 + U[n])
        total += weights[n] * frac
        total_sq += weights[n] * frac ** 2
        total_w += weights[n]

    radius = sqrt(2 * xi / float(N) ** 2)
    radicand = total_sq - (total ** 2) / total_w

    if torch.is_tensor(radicand):
        value = float(radicand)
        if value > 0.0 or value != value:
            # Positive (or NaN, which must keep propagating): differentiable.
            spread = torch.sqrt(radicand)
        else:
            # Zero or a tiny negative rounding residue: the penalty vanishes.
            # A constant is used because d/dx sqrt(x) is infinite at 0.
            spread = 0.0
    else:
        spread = sqrt(max(radicand, 0.0))
    return total / total_w - radius * spread

def FCP_utility(theta, m, p, N, weights):
    """Weighted average of the per-sample ratio of shifted exponentials.

    For each sample the largest linear score
    ``theta[0][n][j] * p[j] + theta[1][n][j]`` over ``j`` is found and its
    negative is passed as ``fix`` to ``FCP_numerator_fix`` /
    ``FCP_denominator_fix``, which add it inside their exponentials.

    Returns:
        ``sum(w_n * ratio_n) / sum(w_n)`` over the first ``N`` samples.
    """
    weighted_sum = 0.0
    weight_mass = 0.0
    for n in range(N):
        # Largest linear score across alternatives (first maximum wins).
        peak = None
        for j in range(m):
            score = theta[0][n][j] * p[j] + theta[1][n][j]
            if peak is None or score > peak:
                peak = score

        top = 0.0
        bottom = 0.0
        for j in range(m):
            top += FCP_numerator_fix(theta, j, p[j], n, -peak)
            bottom += FCP_denominator_fix(theta, j, p[j], n, -peak)

        weighted_sum += weights[n] * top / bottom
        weight_mass += weights[n]
    return weighted_sum / weight_mass

def FCP_numerator_fix(theta, j, p, n, fix):
    """Return ``torch.exp(theta[0][n][j] * p + theta[1][n][j] + fix)``.

    ``fix`` is an additive offset applied inside the exponential. The
    argument of ``torch.exp`` must evaluate to a tensor.
    """
    slope = theta[0][n][j]
    offset = theta[1][n][j]
    shifted = slope * p + offset + fix
    return torch.exp(shifted)

def FCP_denominator_fix(theta, j, p, n,fix):
    """Return the shifted own term plus a ``1/m`` share of the shifted ``theta[2][n]`` term.

    Computes ``exp(theta[0][n][j] * p + theta[1][n][j] + fix)
    + exp(theta[2][n] + fix) / m`` with ``torch.exp`` and
    ``m = len(theta[0][0])``.
    """
    n_alternatives = len(theta[0][0])
    own_term = torch.exp(theta[0][n][j] * p + theta[1][n][j] + fix)
    shared_term = torch.exp(theta[2][n] + fix) / n_alternatives
    return own_term + shared_term


def FCP_values(Z, theta):
    """Per-sample ratio ``sum_j FCP_numerator / sum_j FCP_denominator`` at ``Z``.

    The number of alternatives is ``len(theta[0][0])`` and the number of
    samples is ``len(theta[1])``.

    Returns:
        NumPy array with one ratio per sample.
    """
    n_alternatives = len(theta[0][0])
    n_samples = len(theta[1])
    ratios = []
    for n in range(n_samples):
        top = sum([FCP_numerator(theta, j, Z[j], n) for j in range(n_alternatives)])
        bottom = sum([FCP_denominator(theta, j, Z[j], n) for j in range(n_alternatives)])
        ratios.append(top / bottom)
    return np.array(ratios)

def weighted_FLP_values(X, S, V, U):
    """Per-sample value ``S[i] * (X @ V[i]) / (X @ V[i] + U[i])`` for every row of ``V``.

    Returns:
        A Python list with one value per row of ``V``.
    """
    captured = [X @ V[i] for i in range(V.shape[0])]
    return [S[i] * c / (c + U[i]) for i, c in enumerate(captured)]

def weighted_avg_and_std(values, weights):
    """
    Return the weighted average and standard deviation.

    values, weights -- array-likes with the same shape.

    Returns a tuple ``(average, std)`` where ``std`` is the square root of
    the weighted mean of squared deviations from ``average``.
    """
    values = np.asarray(values)
    average = np.average(values, weights=weights)
    # Fast and numerically precise:
    variance = np.average((values - average) ** 2, weights=weights)
    # `sqrt` is the name imported from math at the top of this file; the bare
    # `math` module itself is never imported here.
    return (average, sqrt(variance))


def gradient_descent(p_i, N, theta, m, rho=None, num_epochs=1000, print_every=100, lr=0.05):
    """Projected gradient ascent on the simplex, returning the best iterate seen.

    With ``rho`` None the objective is ``utility``; otherwise it is
    ``utility_robust`` with ``rho`` passed in its ``xi`` position (as the
    original positional call did), unit weights for all ``N`` samples and
    ``lib=1`` so that the torch implementations of ``numerator`` /
    ``denominator`` are used. (The previous call omitted the required
    ``weights`` argument and raised ``TypeError``.)

    Args:
        p_i: initial point (anything ``torch.tensor`` accepts).
        N, theta, m: forwarded to the objective.
        rho: None for the plain objective, otherwise the robust radius.
        num_epochs: number of ascent steps.
        print_every: progress is printed every this many epochs.
        lr: step size.

    Returns:
        ``p_i`` itself if no iterate improved on it, otherwise the detached
        tensor of the best iterate.
    """
    if rho is None:
        print("GD on ERM gain")

        def objective(point):
            return utility(theta, numerator, denominator, m, point, N)
    else:
        print("GD on DRO gain")
        unit_weights = [1.0] * N

        def objective(point):
            return utility_robust(theta, numerator, denominator, m, point, N, rho, unit_weights, 1)

    best_u = objective(torch.tensor(p_i)).item()
    best_p = p_i
    p = torch.tensor(p_i, requires_grad = True)

    for i in range(num_epochs):
        u = objective(p)
        if u.item() > best_u:
            best_u = u.item()
            best_p = p.detach()
        if i == 0:
            print("Initial : ", u.item())
        u.backward()
        # Ascent step, projection, then a fresh leaf tensor for the next
        # epoch (clone/detach avoids torch.tensor's copy-construct warning).
        stepped = simplex_projection((p + lr * p.grad).detach())
        p = stepped.clone().detach().requires_grad_(True)
        if (i + 1) % print_every == 0:
            print("Epoch : ", (i + 1), " Best : ", best_u)
    return best_p

def gradient_descent_minimize(p_i, N, theta, m, rho=None, num_epochs=1000, print_every=100, lr=0.05):
    """Projected gradient *descent* on the simplex.

    Identical to ``gradient_descent`` except that each step moves against
    the gradient. With ``rho`` None the objective is ``utility``; otherwise
    it is ``utility_robust`` with ``rho`` passed in its ``xi`` position (as
    the original positional call did), unit weights for all ``N`` samples and
    ``lib=1``. (The previous call omitted the required ``weights`` argument
    and raised ``TypeError``.)

    NOTE(review): the returned iterate is still the one with the *largest*
    objective value seen, exactly as before; for a minimiser that looks
    unintended -- confirm with the callers before changing it.

    Returns:
        ``p_i`` itself if no iterate exceeded its objective value, otherwise
        the detached tensor of the iterate with the largest value.
    """
    if rho is None:
        print("GD on ERM gain")

        def objective(point):
            return utility(theta, numerator, denominator, m, point, N)
    else:
        print("GD on DRO gain")
        unit_weights = [1.0] * N

        def objective(point):
            return utility_robust(theta, numerator, denominator, m, point, N, rho, unit_weights, 1)

    best_u = objective(torch.tensor(p_i)).item()
    best_p = p_i
    p = torch.tensor(p_i, requires_grad = True)

    for i in range(num_epochs):
        u = objective(p)
        if u.item() > best_u:
            best_u = u.item()
            best_p = p.detach()
        if i == 0:
            print("Initial : ", u.item())
        u.backward()
        # Descent step, projection, then a fresh leaf tensor for the next
        # epoch (clone/detach avoids torch.tensor's copy-construct warning).
        stepped = simplex_projection((p - lr * p.grad).detach())
        p = stepped.clone().detach().requires_grad_(True)
        if (i + 1) % print_every == 0:
            print("Epoch : ", (i + 1), " Best : ", best_u)
    return best_p


def gradient_descent_weighted(p_i, N, theta, m, weights, rho=None, num_epochs=1000, print_every=100, lr=0.05):
    """Projected gradient ascent on the simplex with a weighted robust objective.

    With ``rho`` None the objective is the unweighted ``utility`` (``weights``
    is not used on that path); otherwise ``utility_robust_weighted`` is
    called with ``rho`` and ``weights``.

    NOTE(review): ``utility_robust_weighted`` is not defined in the visible
    part of this file -- confirm it exists, or whether ``utility_robust``
    (which takes ``weights``) was meant.

    Returns:
        ``p_i`` itself if no iterate improved on it, otherwise the detached
        tensor of the best iterate.
    """
    if(rho is None):
        print("GD on ERM gain")
        best_u = utility(theta, numerator, denominator, m, torch.tensor(p_i), N).item()
    else:
        print("GD on DRO gain")
        best_u = utility_robust_weighted(theta, numerator, denominator, m, torch.tensor(p_i), N, rho, weights).item()
    best_p = p_i
    p = torch.tensor(p_i, requires_grad = True)

    for i in range(num_epochs):
        if(rho is None):
            u = utility(theta, numerator, denominator, m, p, N)
        else:
            u = utility_robust_weighted(theta, numerator, denominator, m, p, N, rho, weights)
        # u = utility(theta, numerator, denominator, m, p, N)
        # Remember the best iterate before taking the next step.
        if(u.item() > best_u):
            best_u = u.item()
            best_p = p.detach()
        if(i == 0):
            print("Initial : ", u.item())
        u.backward()
        # Ascent step followed by projection; p is re-created as a new leaf
        # tensor, so gradients do not accumulate across epochs.
        p = torch.tensor(simplex_projection((p + lr * p.grad).detach()), requires_grad=True)
        if((i + 1) % print_every == 0):
            print("Epoch : ", (i + 1), " Best : ", best_u)
    return best_p

def FLP_gradient_descent(p_i, N, U, V, m, M, weights, batch_size, xi=None, num_epochs=1000, lr=0.05):
    """Mini-batch projected gradient ascent for the FLP objective.

    Each epoch shuffles the ``N`` sample indices, walks over every full batch
    of ``batch_size`` samples, takes an ascent step on the batch objective
    and projects with ``M_simplex_projection``. After each epoch the
    objective on all samples is printed.

    Two defects of the earlier version are fixed here:
    * the batch range stopped one batch early whenever ``N`` was a multiple
      of ``batch_size`` (and ran no batch at all when ``N == batch_size``);
    * the robust branch evaluated the batch with the full ``weights`` array
      instead of the weights of the sampled batch.

    Args:
        p_i: initial point (anything ``torch.tensor`` accepts).
        N: number of samples.
        U, V, weights: per-sample data, indexable by an integer index array.
        m: unused; kept for signature compatibility.
        M: target sum passed to ``M_simplex_projection``.
        batch_size: samples per gradient step; a trailing partial batch is
            skipped.
        xi: None for ``FLP_utility``, otherwise the radius passed to
            ``FLP_utility_robust``.
        num_epochs, lr: number of epochs and step size.

    Returns:
        The final iterate as a detached tensor.
    """
    def objective(point, theta, sample_weights):
        if xi is None:
            return FLP_utility(point, theta, sample_weights)
        return FLP_utility_robust(point, theta, sample_weights, xi)

    p = torch.tensor(p_i, requires_grad = True)
    for i in range(num_epochs):
        arr = np.arange(N)
        np.random.shuffle(arr)
        for start in range(0, N - batch_size + 1, batch_size):
            batch = arr[start:start + batch_size]
            u = objective(p, [U[batch], V[batch]], weights[batch])
            u.backward()
            stepped = M_simplex_projection(M, (p + lr * p.grad).detach())
            # Fresh leaf tensor so gradients never accumulate across steps.
            p = stepped.clone().detach().requires_grad_(True)
        u = objective(p, [U, V], weights)
        print("Epoch : ", (i + 1), " U : ", u.item())
    return p.detach()

def FCP_gradient_descent(p_i, N, full_theta, m, M, weights, batch_size, xi=None, num_epochs=1000, lr=0.05):
    """Mini-batch projected gradient ascent for the FCP objective.

    Each epoch shuffles the ``N`` sample indices, walks over every full batch
    of ``batch_size`` samples, takes an ascent step on ``FCP_utility`` for
    that batch (printing its value) and projects with
    ``M_simplex_projection``. After each epoch the objective on all samples
    is printed.

    Changes relative to the earlier version:
    * the batch range stopped one batch early whenever ``N`` was a multiple
      of ``batch_size`` (and ran no batch at all when ``N == batch_size``);
    * the ``xi is not None`` branch applied the FLP formula to the
      three-part FCP parameters and then referenced undefined names
      ``U``/``V``, so it could never complete an epoch. It now fails
      immediately with an explicit error.

    Args:
        p_i: initial point (anything ``torch.tensor`` accepts).
        N: number of samples.
        full_theta: three-element sequence; each element is indexable by an
            integer index array along its first axis.
        m: number of alternatives, forwarded to ``FCP_utility``.
        M: target sum passed to ``M_simplex_projection``.
        weights: per-sample weights, indexable by an integer index array.
        batch_size: samples per gradient step; a trailing partial batch is
            skipped.
        xi: must be None.
        num_epochs, lr: number of epochs and step size.

    Returns:
        The final iterate as a detached tensor.

    Raises:
        NotImplementedError: if ``xi`` is not None.
    """
    if xi is not None:
        raise NotImplementedError(
            "FCP_gradient_descent has no robust (xi is not None) objective"
        )

    p = torch.tensor(p_i, requires_grad = True)
    for i in range(num_epochs):
        arr = np.arange(N)
        np.random.shuffle(arr)
        for start in range(0, N - batch_size + 1, batch_size):
            batch = arr[start:start + batch_size]
            theta = [full_theta[0][batch], full_theta[1][batch], full_theta[2][batch]]
            u = FCP_utility(theta, m, p, batch_size, weights[batch])
            print(u.item())
            u.backward()
            stepped = M_simplex_projection(M, (p + lr * p.grad).detach())
            # Fresh leaf tensor so gradients never accumulate across steps.
            p = stepped.clone().detach().requires_grad_(True)
        u = FCP_utility(full_theta, m, p, N, weights)
        print("Epoch : ", (i + 1), " U : ", u.item())
    return p.detach()

def SSG_gradient_descent(p_i, N, num, den, full_theta, m, M, weights, batch_size, xi=None, num_epochs=1000, lr=0.05):
    """Mini-batch projected gradient ascent on ``utility_robust``.

    The start point is re-drawn at random (and projected) while the objective
    on all samples evaluates to NaN. Each epoch then shuffles the sample
    indices, steps on every full batch and projects with
    ``M_simplex_projection``; the objective on all samples is printed per
    epoch.

    The NaN test uses the ``isnan`` name imported at the top of this file;
    the bare ``math`` module is never imported here.

    NOTE(review): the initial evaluation always calls ``utility_robust``,
    which needs a numeric ``xi``; with the default ``xi=None`` it fails
    before the ``FCP_utility`` branches below are ever reached -- confirm
    whether the ``xi is None`` mode is meant to be supported.

    Args:
        p_i: initial point (anything ``torch.tensor`` accepts).
        N: number of samples.
        num, den: callables forwarded to ``utility_robust``.
        full_theta, weights: per-sample data, indexable by an integer index
            array.
        m: number of coordinates of ``p``.
        M: target sum passed to ``M_simplex_projection``.
        batch_size: samples per gradient step; a trailing partial batch is
            skipped.
        xi: radius forwarded to ``utility_robust``.
        num_epochs, lr: number of epochs and step size.

    Returns:
        The final iterate as a detached tensor.
    """
    p = torch.tensor(p_i, requires_grad = True)
    u = utility_robust(full_theta, num, den, m, p, N,xi, weights, 1)
    while isnan(u.item()):
        # Restart from a random feasible point until the objective is finite.
        p_i = M_simplex_projection(M, torch.rand(m))
        p = p_i.clone().detach().requires_grad_(True)
        u = utility_robust(full_theta, num, den, m, p, N,xi, weights, 1)

    for i in range(num_epochs):
        arr = np.arange(N)
        np.random.shuffle(arr)
        for idx in range(0, N-batch_size+1, batch_size):
            batch = arr[idx:batch_size+idx]
            theta = full_theta[batch]
            weights_curr = weights[batch]
            if xi is None:
                u = FCP_utility(theta, m, p, batch_size, weights_curr)
            else:
                # N is passed as N_act so the radius uses the full sample
                # count rather than the batch size.
                u = utility_robust(theta, num, den, m, p, batch_size,xi, weights_curr, 1, N)
            u.backward()
            stepped = M_simplex_projection(M, (p + lr * p.grad).detach())
            # Fresh leaf tensor so gradients never accumulate across steps.
            p = stepped.clone().detach().requires_grad_(True)
        if xi is None:
            u = FCP_utility(full_theta, m, p, N, weights)
        else:
            u = utility_robust(full_theta, num, den, m, p, N,xi, weights, 1)
        print("Epoch : ", (i + 1), " U : ", u.item())
    return p.detach()

def maxmin_utility(theta, num, den, m, p, N, w):
    """Sum over samples of ``w[n] * (sum_j num) / (sum_j den)``.

    ``num`` and ``den`` are called as ``f(theta, j, p[j], n, 1)``. Unlike
    ``utility`` the result is not divided by ``N``.
    """
    running = 0.0
    for n in range(N):
        top = sum((num(theta, j, p[j], n, 1) for j in range(m)), 0.0)
        bottom = sum((den(theta, j, p[j], n, 1) for j in range(m)), 0.0)
        running += w[n] * top / bottom
    return running

def proj_P(P_0, xi=1e6):
    """Project ``P_0`` onto a ball around the uniform distribution, within the simplex.

    Solves, with SciPy's ``trust-constr`` method::

        minimise   ||x - P_0||^2
        subject to sum(x) = 1,  0 <= x_i <= 1,
                   ||x - 1/N||^2 <= 2 * xi / N^2

    ``P_0`` is converted to a float NumPy array once, so the objective no
    longer depends on a name that is rebound after the closure is created,
    and the linear constraint is built directly in NumPy.

    Args:
        P_0: 1-D array-like (a detached CPU torch tensor also works).
        xi: radius parameter of the quadratic constraint.

    Returns:
        The solver's solution as a NumPy array of length ``N``.
    """
    target = np.asarray(P_0, dtype=float)
    N = target.shape[0]
    eps = 2 * xi / (N ** 2)
    uniform = np.ones(N) / N

    def fun(x):
        return np.sum((x - target) ** 2)

    def cons(x):
        return np.sum((x - uniform) ** 2)

    # First row: sum(x) in [1, 1]; remaining rows: each x_i in [0, 1].
    matrix = np.vstack([np.ones((1, N)), np.eye(N)])
    lower = np.concatenate([np.ones(1), np.zeros(N)])
    upper = np.ones(N + 1)
    linear_constraint = LinearConstraint(matrix, lower, upper)
    non_linear_constraint = NonlinearConstraint(cons, 0, eps)

    result = minimize(
        fun,
        target,
        method='trust-constr',
        constraints=[linear_constraint, non_linear_constraint],
    )
    return result.x

def SSG_TTGD(p_i, w_i, N, num, den, full_theta, m, M, xi, num_epochs=1000, lr_p=0.001, lr_w=0.001):
    """Alternating gradient ascent in ``p`` and descent in ``w`` on ``maxmin_utility``.

    Every epoch takes one ascent step in ``p`` (projected with
    ``M_simplex_projection``) followed by one descent step in ``w``
    (projected with ``proj_P``), printing the objective before the ascent
    step.

    Fixes relative to the earlier version:
    * the NaN test uses the ``isnan`` name imported at the top of this file
      (the bare ``math`` module is never imported here);
    * gradients are taken with ``torch.autograd.grad`` for exactly the
      variable being updated. Calling ``backward()`` twice per epoch without
      clearing ``.grad`` made ``w.grad`` (and, from the second epoch on,
      ``p.grad``) the sum of gradients evaluated at two different points.

    Args:
        p_i, w_i: initial points (anything ``torch.tensor`` accepts).
        N, num, den, full_theta, m: forwarded to ``maxmin_utility``.
        M: target sum passed to ``M_simplex_projection``.
        xi: radius forwarded to ``proj_P``.
        num_epochs: number of ascent/descent rounds.
        lr_p, lr_w: step sizes for ``p`` and ``w``.

    Returns:
        The final ``p`` as a detached tensor.
    """
    p = torch.tensor(p_i, requires_grad = True)
    w = torch.tensor(w_i, requires_grad = True)
    u = maxmin_utility(full_theta, num, den, m, p, N, w)
    while isnan(u.item()):
        # Restart from a random feasible p until the objective is finite.
        p_i = M_simplex_projection(M, torch.rand(m))
        p = p_i.clone().detach().requires_grad_(True)
        u = maxmin_utility(full_theta, num, den, m, p, N, w)

    for i in range(num_epochs):
        # Ascent in p at the current (p, w).
        u = maxmin_utility(full_theta, num, den, m, p, N, w)
        print("Epoch : ", (i + 1), " U : ", u.item())
        (grad_p,) = torch.autograd.grad(u, p)
        stepped = M_simplex_projection(M, (p + lr_p * grad_p).detach())
        p = stepped.clone().detach().requires_grad_(True)

        # Descent in w at the updated p.
        u = maxmin_utility(full_theta, num, den, m, p, N, w)
        (grad_w,) = torch.autograd.grad(u, w)
        w = torch.tensor(proj_P((w - lr_w * grad_w).detach(), xi), requires_grad=True)
    return p.detach()
